// proxy.cpp
#define _WIN32_DCOM
#include <iostream.h>
#include "Component\component.h"
#include "registry.h"

// Class identifier of the proxy object served by this DLL. DllGetClassObject
// rejects any other CLSID, and DllRegisterServer/DllUnregisterServer pass it
// to the registry helpers.
CLSID CLSID_InsideCOMProxy = {0x10000004,0x0000,0x0000,{0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x01}};

// Number of live CInsideCOM objects (incremented in the constructor,
// decremented in the destructor); consulted by DllCanUnloadNow.
long g_cComponents = 0;
// Net count of IClassFactory::LockServer(TRUE) calls minus LockServer(FALSE)
// calls; consulted by DllCanUnloadNow.
long g_cServerLocks = 0;

class CInsideCOM : public ISum, public IMarshal
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void **ppv);

	// IMarshal
	HRESULT __stdcall GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, CLSID* pClsid);
	HRESULT __stdcall GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, DWORD* pSize);
	HRESULT __stdcall MarshalInterface(IStream* pStream, REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags);
	HRESULT __stdcall DisconnectObject(DWORD dwReserved);
	HRESULT __stdcall UnmarshalInterface(IStream* pStream, REFIID riid, void** ppv);
	HRESULT __stdcall ReleaseMarshalData(IStream* pStream);

	// ISum
	HRESULT __stdcall Sum(int x, int y, int *sum);

	CInsideCOM() : m_cRef(1) { g_cComponents++; }
	~CInsideCOM() { cout << "Proxy: CInsideCOM::~CInsideCOM()" << endl, g_cComponents--; }

private:
	ULONG m_cRef;
	HANDLE hFileMap;
	void* pMem;
	HANDLE hStubEvent;
	HANDLE hProxyEvent;
};

// IMarshal::GetUnmarshalClass - not implemented by the proxy.
// Traces the call and reports E_NOTIMPL; none of the arguments are used.
HRESULT CInsideCOM::GetUnmarshalClass(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, CLSID* pClsid)
{
	const HRESULT hrNotImplemented = E_NOTIMPL;

	cout << "GetUnmarshalClass" << endl;

	return hrNotImplemented;
}

// IMarshal::GetMarshalSizeMax - not implemented by the proxy.
// Traces the call and reports E_NOTIMPL; pSize is left untouched.
HRESULT CInsideCOM::GetMarshalSizeMax(REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags, DWORD* pSize)
{
	const HRESULT hrNotImplemented = E_NOTIMPL;

	cout << "GetMarshalSize" << endl;

	return hrNotImplemented;
}

// IMarshal::MarshalInterface - not implemented by the proxy.
// Traces the call and reports E_NOTIMPL; nothing is written to pStream.
HRESULT CInsideCOM::MarshalInterface(IStream* pStream, REFIID riid, void* pv, DWORD dwDestContext, void* pvDestContext, DWORD dwFlags)
{
	const HRESULT hrNotImplemented = E_NOTIMPL;

	cout << "MarshalInterface" << endl;

	return hrNotImplemented;
}

// IMarshal::DisconnectObject - not implemented by the proxy.
// Traces the call and reports E_NOTIMPL; no connection state is changed.
HRESULT CInsideCOM::DisconnectObject(DWORD dwReserved)
{
	const HRESULT hrNotImplemented = E_NOTIMPL;

	cout << "DisconnectObject" << endl;

	return hrNotImplemented;
}

// IMarshal::UnmarshalInterface - connects this proxy to the other side.
//
// Reads up to 255 bytes from pStream and interprets them as three
// comma-separated object names: "<file mapping>,<stub event>,<proxy event>".
// The file mapping is opened and mapped for writing, the stub event is opened
// for signalling and the proxy event for signalling and waiting. On success
// the requested interface is returned through QueryInterface().
//
// Parameters:
//   pStream - stream holding the marshaling data (must not be NULL)
//   riid    - interface requested by the caller
//   ppv     - receives the interface pointer; set to NULL on every failure
//
// Returns:
//   E_POINTER / E_INVALIDARG for bad arguments or a payload that does not
//   contain three names, the failing HRESULT from IStream::Read, the Win32
//   error (as an HRESULT) if an object cannot be opened, otherwise the result
//   of QueryInterface(). On failure nothing stays open and the connection
//   members are left NULL.
HRESULT CInsideCOM::UnmarshalInterface(IStream* pStream, REFIID riid, void** ppv)
{
	if(ppv == NULL)
		return E_POINTER;
	*ppv = NULL;

	// Start from a known "not connected" state so a failed unmarshal never
	// leaves stale values behind.
	hFileMap = NULL;
	pMem = NULL;
	hStubEvent = NULL;
	hProxyEvent = NULL;

	if(pStream == NULL)
		return E_INVALIDARG;

	// One byte more than is ever read, so the payload can always be
	// NUL-terminated before it is treated as a C string.
	char buffer_to_read[256];
	const unsigned long max_read = sizeof(buffer_to_read) - 1;
	unsigned long num_read = 0;

	HRESULT hr = pStream->Read(buffer_to_read, max_read, &num_read);
	if(FAILED(hr))
		return hr;
	if(num_read > max_read)
		num_read = max_read;
	buffer_to_read[num_read] = '\0';

	cout << "Proxy: CInsideCOM::UnmarshalInterface() = " << buffer_to_read << endl;

	char* pszFileMapName = strtok(buffer_to_read, ",");
	char* pszStubEventName = strtok(NULL, ",");
	char* pszProxyEventName = strtok(NULL, ",");
	if(pszFileMapName == NULL || pszStubEventName == NULL || pszProxyEventName == NULL)
		return E_INVALIDARG;

	// The names are narrow strings, so the ANSI entry points are called
	// explicitly. Each step runs only if the previous one succeeded, which
	// keeps GetLastError() pointing at the first failure.
	HANDLE hMap = OpenFileMappingA(FILE_MAP_WRITE, FALSE, pszFileMapName);
	void* pView = (hMap != NULL) ? MapViewOfFile(hMap, FILE_MAP_WRITE, 0, 0, 0) : NULL;
	HANDLE hStub = (pView != NULL) ? OpenEventA(EVENT_MODIFY_STATE, FALSE, pszStubEventName) : NULL;
	HANDLE hProxy = (hStub != NULL) ? OpenEventA(EVENT_MODIFY_STATE|SYNCHRONIZE, FALSE, pszProxyEventName) : NULL;

	if(hProxy == NULL)
	{
		// Capture the error before the cleanup calls can overwrite it.
		const DWORD dwError = GetLastError();

		if(hStub != NULL)
			CloseHandle(hStub);
		if(pView != NULL)
			UnmapViewOfFile(pView);
		if(hMap != NULL)
			CloseHandle(hMap);

		hr = HRESULT_FROM_WIN32(dwError);
		// Never report success without a connection, even if no error code
		// was recorded.
		return FAILED(hr) ? hr : E_FAIL;
	}

	hFileMap = hMap;
	pMem = pView;
	hStubEvent = hStub;
	hProxyEvent = hProxy;

	// Get the pointer to return to the client
	return QueryInterface(riid, ppv);
}

// IMarshal::ReleaseMarshalData - nothing to release here.
// Traces the call and reports S_OK without reading pStream.
HRESULT CInsideCOM::ReleaseMarshalData(IStream* pStream)
{
	const HRESULT hrDone = S_OK;

	cout << "ReleaseMarshalData" << endl;

	return hrDone;
}

// IUnknown::QueryInterface for the proxy object.
//
// Supports IUnknown, IMarshal and ISum. On success the matching interface
// pointer is stored in *ppv and the reference count is incremented.
//
// Returns:
//   S_OK           - *ppv holds the requested interface
//   E_NOINTERFACE  - riid is not supported; *ppv is set to NULL
//   E_POINTER      - ppv itself is NULL
HRESULT CInsideCOM::QueryInterface(REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	if(riid == IID_IUnknown)
	{
		cout << "Proxy: CInsideCOM::QueryInterface() for IUnknown" << endl;
		// Both bases carry an IUnknown, so the conversion has to name one.
		// Going through ISum with static_cast lets the compiler compute the
		// address instead of relying on ISum being laid out first.
		*ppv = static_cast<IUnknown*>(static_cast<ISum*>(this));
	}
	else if(riid == IID_IMarshal)
	{
		cout << "Proxy: CInsideCOM::QueryInterface() for IMarshal" << endl;
		*ppv = static_cast<IMarshal*>(this);
	}
	else if(riid == IID_ISum)
	{
		cout << "Proxy: CInsideCOM::QueryInterface() for ISum" << endl;
		*ppv = static_cast<ISum*>(this);
	}
	else
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	AddRef();
	return S_OK;
}

// IUnknown::AddRef - traces the new count, stores it and returns it.
ULONG CInsideCOM::AddRef()
{
	const ULONG cUpdated = m_cRef + 1;

	cout << "Proxy: CInsideCOM::AddRef() m_cRef = " << cUpdated << endl;

	m_cRef = cUpdated;
	return cUpdated;
}

// IUnknown::Release for the proxy object.
//
// Traces and decrements the reference count. When the count reaches zero the
// other side is asked to destroy its object by writing method id 1 into the
// shared view and setting the stub event, and then this proxy deletes itself.
//
// Returns the new reference count (0 once the object has been deleted).
ULONG CInsideCOM::Release()
{
	cout << "Proxy: CInsideCOM::Release() m_cRef = " << m_cRef - 1 << endl;
	if(--m_cRef != 0)
		return m_cRef;

	// Only notify the other side when there is a view to write to and an
	// event to signal. MapViewOfFile/OpenEvent return NULL on failure, and
	// writing through a NULL view would crash instead of releasing cleanly.
	if(pMem != NULL && hStubEvent != NULL)
	{
		const short method_id = 1; // ISum::Release
		memcpy(pMem, &method_id, sizeof(method_id));
		cout << "Proxy Setting event to destroy object " << endl;
		SetEvent(hStubEvent);
	}

	delete this;
	return 0;
}

// Argument/result record for ISum::Sum as it is copied into and out of the
// shared-memory view by CInsideCOM::Sum. The field order and types are part
// of that exchange format, so they must not be changed on one side only.
struct SumTransmit
{
	int x;    // first operand
	int y;    // second operand
	int sum;  // result, read back by the proxy after the call completes
};

// ISum::Sum - forwards the call through the shared-memory view.
//
// Writes method id 2 at the start of the view followed by a SumTransmit
// record holding x and y, sets the stub event, waits without timeout on the
// proxy event, then reads a SumTransmit back and stores its sum in *sum.
//
// Parameters:
//   x, y - operands
//   sum  - receives the result (must not be NULL)
//
// Returns:
//   S_OK on success, E_POINTER if sum is NULL, E_UNEXPECTED if the proxy has
//   no view or events to work with, or the Win32 error (as an HRESULT) when
//   signalling or waiting fails.
HRESULT CInsideCOM::Sum(int x, int y, int *sum)
{
	if(sum == NULL)
		return E_POINTER;
	if(pMem == NULL || hStubEvent == NULL || hProxyEvent == NULL)
		return E_UNEXPECTED;

	SumTransmit s;
	s.x = x;
	s.y = y;
	s.sum = 0; // do not copy an indeterminate value into shared memory
	const short method_id = 2; // ISum::Sum

	// Request layout: method id first, arguments directly after it.
	memcpy(pMem, &method_id, sizeof(method_id));
	memcpy(static_cast<short*>(pMem) + 1, &s, sizeof(SumTransmit));

	cout << "Proxy: Going to call component CInsideCOM::Sum() " << x << " + " << y << endl;

	if(!SetEvent(hStubEvent))
	{
		const DWORD dwError = GetLastError();
		const HRESULT hrSignal = HRESULT_FROM_WIN32(dwError);
		return FAILED(hrSignal) ? hrSignal : E_FAIL;
	}

	const DWORD dwWait = WaitForSingleObject(hProxyEvent, INFINITE);
	if(dwWait == WAIT_FAILED)
	{
		const DWORD dwError = GetLastError();
		const HRESULT hrWait = HRESULT_FROM_WIN32(dwError);
		return FAILED(hrWait) ? hrWait : E_FAIL;
	}
	if(dwWait != WAIT_OBJECT_0)
		return E_FAIL;

	// NOTE(review): the reply is read from the start of the view, not from
	// the offset the request was written at. This is kept as-is because it
	// presumably matches what the other side writes - confirm against the stub.
	memcpy(&s, pMem, sizeof(s));

	*sum = s.sum;
	return S_OK;
}

// Class factory for the proxy object. Its reference count starts at 1, the
// reference held by whoever created it.
class CFactory : public IClassFactory
{
public:
	CFactory() : m_cRef(1) { }
	~CFactory() { }

	// IUnknown
	HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();

	// IClassFactory
	HRESULT __stdcall CreateInstance(IUnknown *pUnknownOuter, REFIID riid, void** ppv);
	HRESULT __stdcall LockServer(BOOL bLock);

private:
	ULONG m_cRef; // COM reference count
};

// IUnknown::AddRef - traces the new count, stores it and returns it.
ULONG CFactory::AddRef()
{
	const ULONG cUpdated = m_cRef + 1;

	cout << "Proxy: CFactory::AddRef() m_cRef = " << cUpdated << endl;

	m_cRef = cUpdated;
	return cUpdated;
}

// IUnknown::Release - traces and stores the decremented count; the factory
// deletes itself when the count reaches zero. Returns the new count.
ULONG CFactory::Release()
{
	const ULONG cRemaining = m_cRef - 1;

	cout << "Proxy: CFactory::Release() m_cRef = " << cRemaining << endl;

	m_cRef = cRemaining;
	if(cRemaining == 0)
	{
		delete this;
		return 0;
	}
	return cRemaining;
}

// IUnknown::QueryInterface for the factory.
// Hands out the IClassFactory pointer for IUnknown and IClassFactory and adds
// a reference; any other interface yields E_NOINTERFACE with *ppv set to NULL.
HRESULT CFactory::QueryInterface(REFIID riid, void **ppv)
{
	const bool fSupported = (riid == IID_IUnknown) || (riid == IID_IClassFactory);

	// Reject unsupported interfaces first.
	if(!fSupported)
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}

	cout << "Proxy: CFactory::QueryInteface() for IUnknown or IClassFactory " << this << endl;
	*ppv = static_cast<IClassFactory*>(this);
	AddRef();
	return S_OK;
}

// IClassFactory::CreateInstance - creates a new proxy object.
//
// Parameters:
//   pUnknownOuter - must be NULL; aggregation is not supported
//   riid          - interface requested on the new object
//   ppv           - receives the interface pointer; NULL on every failure
//
// Returns:
//   E_POINTER if ppv is NULL, CLASS_E_NOAGGREGATION if pUnknownOuter is not
//   NULL, E_OUTOFMEMORY if the object could not be allocated, otherwise the
//   result of the object's QueryInterface().
HRESULT CFactory::CreateInstance(IUnknown *pUnknownOuter, REFIID riid, void **ppv)
{
	if(ppv == NULL)
		return E_POINTER;
	// Callers must never see a stale pointer on a failure path.
	*ppv = NULL;

	if(pUnknownOuter != NULL)
		return CLASS_E_NOAGGREGATION;

	// NOTE(review): with a conforming operator new an allocation failure
	// throws instead of returning NULL, so this check only helps on toolchains
	// where new returns NULL - confirm which applies to this build.
	CInsideCOM *pInsideCOM = new CInsideCOM;
	cout << "Proxy: CFactory::CreateInstance() " << pInsideCOM << endl;

	if(pInsideCOM == NULL)
		return E_OUTOFMEMORY;

	// The object is born with one reference; QueryInterface adds the caller's
	// reference and the Release below drops the creation reference. If the
	// interface is unsupported, that Release destroys the object.
	HRESULT hr = pInsideCOM->QueryInterface(riid, ppv);
	pInsideCOM->Release();
	return hr;
}

// IClassFactory::LockServer - adjusts the global server lock count:
// +1 when bLock is non-zero, -1 otherwise. Always returns S_OK.
HRESULT CFactory::LockServer(BOOL bLock)
{
	const long lDelta = bLock ? 1 : -1;

	g_cServerLocks += lDelta;
	return S_OK;
}

// Reports whether the DLL may be unloaded: S_OK when there are neither server
// locks nor live components, S_FALSE otherwise. The decision is traced.
HRESULT __stdcall DllCanUnloadNow()
{
	const bool fIdle = (g_cServerLocks == 0) && (g_cComponents == 0);

	cout << "Proxy: DllCanUnloadNow() " << (fIdle ? "Yes" : "No") << endl;

	return fIdle ? S_OK : S_FALSE;
}

// Returns the class factory for CLSID_InsideCOMProxy.
//
// Parameters:
//   clsid - class being requested; only CLSID_InsideCOMProxy is served
//   riid  - interface requested on the factory
//   ppv   - receives the interface pointer; NULL on every failure
//
// Returns:
//   E_POINTER if ppv is NULL, CLASS_E_CLASSNOTAVAILABLE for any other CLSID,
//   E_OUTOFMEMORY if the factory could not be allocated, otherwise the result
//   of the factory's QueryInterface().
HRESULT __stdcall DllGetClassObject(REFCLSID clsid, REFIID riid, void** ppv)
{
	cout << "Proxy: DllGetClassObject" << endl;

	if(ppv == NULL)
		return E_POINTER;
	// Callers must never see a stale pointer on a failure path.
	*ppv = NULL;

	if(clsid != CLSID_InsideCOMProxy)
		return CLASS_E_CLASSNOTAVAILABLE;

	CFactory* pFactory = new CFactory;
	if(pFactory == NULL)
		return E_OUTOFMEMORY;

	// QueryInterface adds the caller's reference; the Release drops the
	// creation reference, destroying the factory if the interface was refused.
	HRESULT hr = pFactory->QueryInterface(riid, ppv);
	pFactory->Release();
	return hr;
}

// Self-registration entry point: forwards the module name, CLSID, friendly
// name and the two ProgIDs to RegisterServer() and returns its result.
HRESULT __stdcall DllRegisterServer()
{
	const HRESULT hrRegistered = RegisterServer(
		"proxy.dll",
		CLSID_InsideCOMProxy,
		"Inside COM Sample #1",
		"Component.InsideCOM",
		"Component.InsideCOM.1",
		NULL);

	return hrRegistered;
}

// Self-unregistration entry point: forwards the CLSID and the two ProgIDs to
// UnregisterServer() and returns its result.
HRESULT __stdcall DllUnregisterServer()
{
	const HRESULT hrUnregistered = UnregisterServer(
		CLSID_InsideCOMProxy,
		"Component.InsideCOM",
		"Component.InsideCOM.1");

	return hrUnregistered;
}

// DLL entry point. Performs no per-process or per-thread work and always
// reports success; none of the arguments are inspected.
BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved)
{
	// Trace output left disabled:
	//	cout << "Proxy: DllMain()" << endl;
	const BOOL fLoaded = TRUE;

	return fLoaded;
}